{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyMPAOeEbrTU6JWjzESEf3uo",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sugarforever/LangChain-Advanced/blob/main/Retrievers/03_Ensemble_Retriever.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xPWZLTgY5E_4"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "# NOTE(review): versions are unpinned; pin them if this notebook must be reproducible.\n",
        "%pip install -q -U langchain openai chromadb tiktoken rank_bm25"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from getpass import getpass\n",
        "\n",
        "# Keep a key that is already set in the environment; otherwise prompt for it\n",
        "# so that no credential (or placeholder) is ever stored in the notebook.\n",
        "if not os.environ.get(\"OPENAI_API_KEY\"):\n",
        "    os.environ[\"OPENAI_API_KEY\"] = getpass(\"OpenAI API key: \")"
      ],
      "metadata": {
        "id": "eVQjKJn-5ZZZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.embeddings import OpenAIEmbeddings\n",
        "from langchain.retrievers import BM25Retriever\n",
        "from langchain.vectorstores import Chroma"
      ],
      "metadata": {
        "id": "7qjgj56XSn5E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import Any, Dict, List\n",
        "\n",
        "from langchain.callbacks.manager import (\n",
        "    AsyncCallbackManagerForRetrieverRun,\n",
        "    CallbackManagerForRetrieverRun,\n",
        ")\n",
        "from langchain.pydantic_v1 import root_validator\n",
        "from langchain.schema import BaseRetriever, Document\n",
        "\n",
        "\n",
        "class EnsembleRetriever(BaseRetriever):\n",
        "    \"\"\"Retriever that combines the results of several retrievers.\n",
        "\n",
        "    Every underlying retriever is queried with the same query and the ranked\n",
        "    result lists are merged with weighted Reciprocal Rank Fusion (RRF).\n",
        "\n",
        "    Args:\n",
        "        retrievers: A list of retrievers to ensemble.\n",
        "        weights: A list of weights corresponding to the retrievers. Defaults to equal\n",
        "            weighting for all retrievers.\n",
        "        c: A constant added to the rank, controlling the balance between the importance\n",
        "            of high-ranked items and the consideration given to lower-ranked items.\n",
        "            Default is 60.\n",
        "    \"\"\"\n",
        "\n",
        "    retrievers: List[BaseRetriever]\n",
        "    weights: List[float]\n",
        "    c: int = 60\n",
        "\n",
        "    @root_validator(pre=True)\n",
        "    def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:\n",
        "        \"\"\"Fill in equal weights when no (or empty) weights are supplied.\n",
        "\n",
        "        If ``retrievers`` is absent the values are returned untouched, so the\n",
        "        missing field is left to field validation instead of raising a\n",
        "        ``KeyError`` here. An empty ``retrievers`` list gets an empty weight\n",
        "        list instead of dividing by zero.\n",
        "\n",
        "        Args:\n",
        "            values: The raw constructor arguments.\n",
        "\n",
        "        Returns:\n",
        "            The same mapping, with ``weights`` populated when it was missing.\n",
        "        \"\"\"\n",
        "        retrievers = values.get(\"retrievers\")\n",
        "        if retrievers is not None and not values.get(\"weights\"):\n",
        "            n_retrievers = len(retrievers)\n",
        "            values[\"weights\"] = (\n",
        "                [1 / n_retrievers] * n_retrievers if n_retrievers else []\n",
        "            )\n",
        "        return values\n",
        "\n",
        "    def _get_relevant_documents(\n",
        "        self,\n",
        "        query: str,\n",
        "        *,\n",
        "        run_manager: CallbackManagerForRetrieverRun,\n",
        "    ) -> List[Document]:\n",
        "        \"\"\"\n",
        "        Get the relevant documents for a given query.\n",
        "\n",
        "        Args:\n",
        "            query: The query to search for.\n",
        "            run_manager: Callback manager of this run; passed on to rank_fusion.\n",
        "\n",
        "        Returns:\n",
        "            A list of reranked documents.\n",
        "        \"\"\"\n",
        "\n",
        "        # Get fused result of the retrievers.\n",
        "        fused_documents = self.rank_fusion(query, run_manager)\n",
        "\n",
        "        return fused_documents\n",
        "\n",
        "    async def _aget_relevant_documents(\n",
        "        self,\n",
        "        query: str,\n",
        "        *,\n",
        "        run_manager: AsyncCallbackManagerForRetrieverRun,\n",
        "    ) -> List[Document]:\n",
        "        \"\"\"\n",
        "        Asynchronously get the relevant documents for a given query.\n",
        "\n",
        "        Args:\n",
        "            query: The query to search for.\n",
        "            run_manager: Callback manager of this run; passed on to arank_fusion.\n",
        "\n",
        "        Returns:\n",
        "            A list of reranked documents.\n",
        "        \"\"\"\n",
        "\n",
        "        # Get fused result of the retrievers.\n",
        "        fused_documents = await self.arank_fusion(query, run_manager)\n",
        "\n",
        "        return fused_documents\n",
        "\n",
        "    def rank_fusion(\n",
        "        self, query: str, run_manager: CallbackManagerForRetrieverRun\n",
        "    ) -> List[Document]:\n",
        "        \"\"\"\n",
        "        Retrieve the results of the retrievers and use\n",
        "        weighted_reciprocal_rank to get the final result.\n",
        "\n",
        "        Args:\n",
        "            query: The query to search for.\n",
        "            run_manager: Callback manager of this run. Each retriever gets a\n",
        "                child manager tagged ``retriever_<n>`` (1-based position).\n",
        "\n",
        "        Returns:\n",
        "            A list of reranked documents.\n",
        "        \"\"\"\n",
        "\n",
        "        # Get the results of all retrievers, in the order of self.retrievers\n",
        "        # so that they line up with self.weights.\n",
        "        retriever_docs = [\n",
        "            retriever.get_relevant_documents(\n",
        "                query, callbacks=run_manager.get_child(tag=f\"retriever_{i+1}\")\n",
        "            )\n",
        "            for i, retriever in enumerate(self.retrievers)\n",
        "        ]\n",
        "\n",
        "        # apply rank fusion\n",
        "        fused_documents = self.weighted_reciprocal_rank(retriever_docs)\n",
        "\n",
        "        return fused_documents\n",
        "\n",
        "    async def arank_fusion(\n",
        "        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun\n",
        "    ) -> List[Document]:\n",
        "        \"\"\"\n",
        "        Asynchronously retrieve the results of the retrievers\n",
        "        and use weighted_reciprocal_rank to get the final result.\n",
        "\n",
        "        The retrievers are awaited one after another, in the order of\n",
        "        self.retrievers.\n",
        "\n",
        "        Args:\n",
        "            query: The query to search for.\n",
        "            run_manager: Callback manager of this run. Each retriever gets a\n",
        "                child manager tagged ``retriever_<n>`` (1-based position).\n",
        "\n",
        "        Returns:\n",
        "            A list of reranked documents.\n",
        "        \"\"\"\n",
        "\n",
        "        # Get the results of all retrievers.\n",
        "        retriever_docs = [\n",
        "            await retriever.aget_relevant_documents(\n",
        "                query, callbacks=run_manager.get_child(tag=f\"retriever_{i+1}\")\n",
        "            )\n",
        "            for i, retriever in enumerate(self.retrievers)\n",
        "        ]\n",
        "\n",
        "        # apply rank fusion\n",
        "        fused_documents = self.weighted_reciprocal_rank(retriever_docs)\n",
        "\n",
        "        return fused_documents\n",
        "\n",
        "    def weighted_reciprocal_rank(\n",
        "        self, doc_lists: List[List[Document]]\n",
        "    ) -> List[Document]:\n",
        "        \"\"\"\n",
        "        Perform weighted Reciprocal Rank Fusion on multiple rank lists.\n",
        "        You can find more details about RRF here:\n",
        "        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf\n",
        "\n",
        "        Documents are identified by their page_content. A document found at\n",
        "        1-based position ``rank`` of a list with weight ``w`` contributes\n",
        "        ``w * 1 / (rank + c)`` to its score. Documents with equal scores keep\n",
        "        the order in which they were first seen across doc_lists.\n",
        "\n",
        "        Args:\n",
        "            doc_lists: A list of rank lists, where each rank list contains unique items.\n",
        "\n",
        "        Returns:\n",
        "            list: The final aggregated list of items sorted by their weighted RRF\n",
        "                    scores in descending order.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: If the number of rank lists differs from the number of weights.\n",
        "        \"\"\"\n",
        "        if len(doc_lists) != len(self.weights):\n",
        "            raise ValueError(\n",
        "                \"Number of rank lists must be equal to the number of weights.\"\n",
        "            )\n",
        "\n",
        "        # Accumulate the scores in a dict, which keeps insertion order, so the\n",
        "        # keys stay in first-seen order. Collecting them in a set instead would\n",
        "        # make the order of equally scored documents depend on string hashing.\n",
        "        rrf_score_dic: Dict[str, float] = {}\n",
        "        page_content_to_doc_map: Dict[str, Document] = {}\n",
        "        for doc_list, weight in zip(doc_lists, self.weights):\n",
        "            for rank, doc in enumerate(doc_list, start=1):\n",
        "                rrf_score = weight * (1 / (rank + self.c))\n",
        "                rrf_score_dic[doc.page_content] = (\n",
        "                    rrf_score_dic.get(doc.page_content, 0.0) + rrf_score\n",
        "                )\n",
        "                # When several lists contain the same page_content, the\n",
        "                # Document object from the last list is the one returned.\n",
        "                page_content_to_doc_map[doc.page_content] = doc\n",
        "\n",
        "        # Show every document's fused score so the ranking can be inspected.\n",
        "        for key, value in rrf_score_dic.items():\n",
        "            print(f'Key: {key}, Value: {value}')\n",
        "\n",
        "        # Sort by score, highest first. sorted() is stable (also with\n",
        "        # reverse=True), so ties keep their first-seen order.\n",
        "        sorted_documents = sorted(\n",
        "            rrf_score_dic, key=rrf_score_dic.__getitem__, reverse=True\n",
        "        )\n",
        "\n",
        "        # Map the sorted page_content back to the document objects\n",
        "        sorted_docs = [\n",
        "            page_content_to_doc_map[page_content] for page_content in sorted_documents\n",
        "        ]\n",
        "\n",
        "        return sorted_docs\n"
      ],
      "metadata": {
        "id": "33D_Wq_042Gu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Small corpus that both retrievers index.\n",
        "doc_list = [\n",
        "    \"I like apples\",\n",
        "    \"I like oranges\",\n",
        "    \"Apples and oranges are fruits\",\n",
        "]\n",
        "\n",
        "# BM25 retriever over the raw texts, limited to the top 2 results.\n",
        "bm25_retriever = BM25Retriever.from_texts(doc_list)\n",
        "bm25_retriever.k = 2\n",
        "\n",
        "docs = bm25_retriever.get_relevant_documents(\"apple\")\n",
        "print(docs)\n",
        "\n",
        "# Chroma vector store built from the same texts with OpenAI embeddings,\n",
        "# exposed as a retriever that also returns the top 2 results.\n",
        "embedding = OpenAIEmbeddings()\n",
        "vectorstore = Chroma.from_texts(\n",
        "    doc_list,\n",
        "    embedding,\n",
        "    collection_name=\"tutorial_2023\",\n",
        ")\n",
        "vs_retriever = vectorstore.as_retriever(search_kwargs={\"k\": 2})\n",
        "\n",
        "vs_docs = vs_retriever.get_relevant_documents(\"apple\")\n",
        "print(vs_docs)\n",
        "\n",
        "# Ensemble of both retrievers with equal weights.\n",
        "ensemble_retriever = EnsembleRetriever(\n",
        "    retrievers=[bm25_retriever, vs_retriever],\n",
        "    weights=[0.5, 0.5],\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lafBN9JvSqZi",
        "outputId": "12d16d4b-7cc2-4986-af57-bf274a8a0271"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[Document(page_content='Apples and oranges are fruits'), Document(page_content='I like oranges')]\n",
            "[Document(page_content='I like apples'), Document(page_content='Apples and oranges are fruits')]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the same query through the ensemble. The per-document scores are printed\n",
        "# by weighted_reciprocal_rank; the fused ranking is displayed as the cell result.\n",
        "docs = ensemble_retriever.get_relevant_documents(\n",
        "    \"apple\"\n",
        ")\n",
        "docs"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rmftLuuJStBg",
        "outputId": "d91be41d-9f90-4553-8a61-c1267ba06b03"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Key: Apples and oranges are fruits, Value: 0.01626123744050767\n",
            "Key: I like oranges, Value: 0.008064516129032258\n",
            "Key: I like apples, Value: 0.00819672131147541\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[Document(page_content='Apples and oranges are fruits'),\n",
              " Document(page_content='I like apples'),\n",
              " Document(page_content='I like oranges')]"
            ]
          },
          "metadata": {},
          "execution_count": 34
        }
      ]
    }
  ]
}